load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")
load("@rules_python//python:defs.bzl", "py_binary", "py_library")
load("@rules_python//python:pip.bzl", "compile_pip_requirements")

# Problem directories that ship a generate_problem.py script. Every entry gets
# a `<name>_generate_problem` py_binary and a `<name>_data` genrule further
# down in this file, and is added to the include paths of osqp_tester.
_GENERATE_DATA_PROBLEM = [
    "basic_lp",
    "basic_qp",
    "basic_qp2",
    "codegen",
    # TODO: fix test failure
    # "derivative_adjoint",
    "lin_alg",
    "non_cvx",
    "primal_dual_infeasibility",
    "primal_infeasibility",
    "solve_linsys",
    "unconstrained",
    "update_matrices",
]

# Every problem directory whose test_*.cpp files are candidates for
# osqp_tester: the generated-data problems plus three that have no
# generator target in this file.
_PROBLEM = _GENERATE_DATA_PROBLEM + [
    "csc_api",
    "large_qp",
    "settings",
]

# Python sources directly under utils/; a dependency of every
# `<name>_generate_problem` binary declared below.
py_library(
    name = "utils",
    srcs = glob(["utils/*.py"]),
)

# Declare one `<problem>_generate_problem` py_binary for each entry of
# _GENERATE_DATA_PROBLEM. Its only source is `<problem>/generate_problem.py`,
# and it depends on the shared :utils library plus numpy and scipy from the
# pip hub.
[
    py_binary(
        name = "%s_generate_problem" % problem,
        srcs = ["%s/generate_problem.py" % problem],
        main = "generate_problem.py",
        deps = [
            ":utils",
            "@osqp_test_pip_deps//numpy",
            "@osqp_test_pip_deps//scipy",
        ],
    )
    for problem in _GENERATE_DATA_PROBLEM
]

# Declare one `<problem>_data` genrule for each entry of
# _GENERATE_DATA_PROBLEM. The command runs the matching generator binary,
# then copies `<problem>/<problem>_data.cpp` and `<problem>/<problem>_data.h`
# from the working directory into the declared output locations.
# NOTE(review): this presumes the generator writes those two files into
# `<problem>/` relative to the directory the command runs in -- confirm
# against generate_problem.py.
[
    genrule(
        name = problem + "_data",
        outs = [
            "{0}/{0}_data.cpp".format(problem),
            "{0}/{0}_data.h".format(problem),
        ],
        cmd = """
        $(location :{p}_generate_problem)
        mkdir -p $(@D)/{p}
        cp {p}/{p}_data.cpp {p}/{p}_data.h $(@D)/{p}
    """.format(p = problem),
        tools = [":%s_generate_problem" % problem],
    )
    for problem in _GENERATE_DATA_PROBLEM
]

# rules_python macro pairing the hand-written requirements.in with its
# lock file requirements_lock.txt (see the rules_python documentation for the
# targets this macro generates).
compile_pip_requirements(
    name = "requirements",
    src = "requirements.in",
    requirements_txt = "requirements_lock.txt",
)

# Wraps the outputs of the :lin_alg_data genrule as headers and puts lin_alg/
# on the include path of dependents. Private to this package; lin_alg_tester
# is the only visible consumer in this file.
cc_library(
    name = "lin_alg_data_header",
    hdrs = [":lin_alg_data"],
    includes = ["lin_alg"],
    visibility = ["//visibility:private"],
)

# Test binary for the linear-algebra test cases under lin_alg/.
cc_test(
    name = "lin_alg_tester",
    size = "small",
    # Hand-written sources and headers matched by glob (the CUDA vector test
    # is excluded), followed by the generated lin_alg data files.
    srcs = glob(
        [
            "osqp_api.h",
            "osqp_tester.h",
            "lin_alg/lin_alg_tester.cpp",
            "lin_alg/**/test_*.cpp",
            "lin_alg/**/test_*.h",
            "utils/test_utils.cpp",
            "utils/test_utils.h",
        ],
        exclude = ["lin_alg/testcases/cuda/test_vector_cuda.cpp"],
    ) + [":lin_alg_data"],
    # Include search directories, relative to this package.
    includes = [
        "lin_alg",
        "utils",
    ],
    deps = [
        ":lin_alg_data_header",
        "@catch2",
        "@osqp",
    ],
)

# Main OSQP test binary.
#
# Sources are the shared tester/utility files, the large_qp data files, and
# the test_*.cpp files of every problem in _PROBLEM whose name does not
# contain "solve_linsys", "codegen" or "lin_alg". Two CUDA-specific tests are
# excluded. The generated data of every _GENERATE_DATA_PROBLEM entry is
# appended to the sources, and each of those problem directories is also an
# include search directory.
cc_test(
    name = "osqp_tester",
    size = "small",
    srcs = glob(
        [
            "osqp_api.h",
            "osqp_tester.cpp",
            "osqp_tester.h",
            # TODO: support codegen test
            # "codegen/**/*.c",
            "large_qp/large_qp_data.c",
            "large_qp/large_qp_data.h",
            "utils/test_utils.cpp",
            "utils/test_utils.h",
        ] + [
            "%s/**/test_*.cpp" % problem
            for problem in _PROBLEM
            if not ("solve_linsys" in problem or "codegen" in problem or "lin_alg" in problem)
        ],
        exclude = [
            "basic_qp/test_cuda_io.cpp",
            "lin_alg/testcases/cuda/test_vector_cuda.cpp",
        ],
    ) + [
        ":%s_data" % problem
        for problem in _GENERATE_DATA_PROBLEM
    ],
    includes = [
        "utils",
    ] + _GENERATE_DATA_PROBLEM,
    deps = [
        "@catch2",
        "@osqp",
    ],
)
